import collections
import json
import os
import pickle
import traceback
from collections import OrderedDict

import numpy as np
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

from generic.data_util import load_event_data, ICEHOCKEY_ACTIONS, ICEHOCKEY_GAME_FEATURES, pad_sequence, \
    build_trace_mask, Transition, \
    read_feature_mean_scale, reverse_standard_data, QValueDiscretization, handle_gda_features, label_visiting_shrink, \
    entropy, read_features_within_events, divide_dataset_according2date, read_feature_max_min, SOCCER_ACTIONS, \
    SOCCER_GAME_FEATURES
from generic.gmm_util import gmm_fit
from density_model.maf_model import build_maf, validate_maf
from generic.model_util import to_pt, to_np
from layers.general_nn import ResBlock
from layers.q_nn import Spline_DQN, DQN


class SportsAgent(object):

    def __init__(self, config, log_file):
        """Set up the agent from a parsed configuration.

        Reads the configuration, builds the networks for RL tasks (plus the MAF
        density model for distributional tasks) and loads the per-sport feature
        statistics.

        :param config: nested configuration mapping; ``read_config`` lists the keys that are read.
        :param log_file: writable stream forwarded to ``construct_model``.
        :raises ValueError: if the configured sport is neither 'ice-hockey' nor 'soccer'.
        """
        self.log_file = log_file
        self.config = config
        self.read_config()
        self.device = 'cuda' if self.enable_cuda else 'cpu'

        # Networks exist only for RL tasks; distributional tasks also get a MAF model.
        if 'rl' in self.task:
            self.construct_model(log_file=log_file)
            if 'distrib' in self.task:
                build_maf(self)

        # Resolve the directory holding the feature statistics for this sport.
        is_ice_hockey = self.sports == 'ice-hockey'
        if is_ice_hockey:
            stats_dir = '../icehockey-data/'
        elif self.sports == 'soccer':
            stats_dir = '../soccer-data/'
        else:
            raise ValueError("Unknown sports: {0}".format(self.sports))

        self.data_maxs, self.data_mins = read_feature_max_min(data_dir=stats_dir)
        self.data_means, self.data_stds = read_feature_mean_scale(data_dir=stats_dir)

        # Only ice-hockey reads a home/away game-id file; soccer leaves it unset.
        if is_ice_hockey:
            with open(self.home_away_ids_dir, 'r') as f:
                self.home_away_game_ids = json.load(f)
        else:
            self.home_away_game_ids = None

    def read_config(self):
        """Copy the settings under ``self.config['general']`` onto instance attributes.

        General and data settings are always read. Training, model and checkpoint
        settings are read only when the task name contains 'rl'; GDA, MAF and LOF
        settings only when it additionally contains 'distrib'.

        :raises KeyError: if a required configuration key is missing.
        :raises ValueError: for a distributional task with an unknown sport.
        """
        general = self.config['general']
        self.task = general['task']
        self.sports = general['sports']

        data_cfg = general['data']
        self.train_data_path = data_cfg['data_path']
        self.source_data_dir = data_cfg['source_data_dir']
        self.home_away_ids_dir = data_cfg['home_away_ids_dir']

        training_cfg = general['training']
        self.train_rate = training_cfg['train_rate']
        self.enable_cuda = general['use_cuda'] and torch.cuda.is_available()

        if 'rl' not in self.task:
            return

        self.gamma = training_cfg['gamma']
        self.batch_size = training_cfg['batch_size']
        self.max_episode = training_cfg['max_episode']
        self.cut_at_goal = training_cfg['cut_at_goal']
        self.keep_goal_state = training_cfg['keep_goal_state']
        self.learning_rate = training_cfg['learning_rate']
        self.apply_data_date_div = training_cfg['apply_data_date_div']

        model_cfg = general['model']
        self.input_dim = model_cfg['input_dim']

        # rnn hyper-parameters (the trace-length flag is read whether or not the RNN is on)
        self.apply_rnn = model_cfg['apply_rnn']
        self.apply_dynamic_trace_length = model_cfg['apply_dynamic_trace_length']

        # resnet hyper-parameters, with neutral defaults when the ResNet is off
        self.apply_resnet = model_cfg['apply_resnet']
        self.resnet_layer_num = model_cfg['resnet_layer_num'] if self.apply_resnet else 0
        self.split_state_action_resnet = \
            model_cfg['split_state_action_resnet'] if self.apply_resnet else False

        self.num_tau = model_cfg['num_tau']
        self.num_supp = model_cfg['num_supp']
        self.block_hidden_dim = model_cfg['block_hidden_dim']
        self.max_trace_length = model_cfg['max_trace_length']

        checkpoint_cfg = general['checkpoint']
        self.report_frequency = checkpoint_cfg['report_frequency']
        self.update_target_frequency = checkpoint_cfg['update_target_frequency']
        self.save_model_dir = checkpoint_cfg['save_model_dir']
        self.experiment_tag = checkpoint_cfg['experiment_tag']

        if 'distrib' not in self.task:
            self.use_expectation_base = None
            return

        # gda features
        gda_cfg = general['gda']
        self.gda_fitting_target = gda_cfg['gda_fitting_target']
        self.gda_apply_history = gda_cfg['apply_history']
        self.gda_apply_pd = gda_cfg['apply_pd']
        self.gda_discret_mode = gda_cfg['discret_mode']

        # maf features
        maf_cfg = general['maf']
        self.maf_flow_type = maf_cfg['flow_type']
        self.maf_apply_history = maf_cfg['apply_history']
        self.maf_num_blocks = maf_cfg['num_blocks']
        self.maf_num_inputs = maf_cfg['num_inputs']
        self.maf_num_hidden = maf_cfg['num_hidden']
        self.maf_lr = maf_cfg['learning_rate']
        self.maf_cond_act = maf_cfg['condition_on_action']
        self.maf_cond_value = maf_cfg['condtion_on_values']

        self.use_expectation_base = model_cfg['use_expectation_base']

        if self.sports == 'ice-hockey':
            actions = ICEHOCKEY_ACTIONS
        elif self.sports == 'soccer':
            actions = SOCCER_ACTIONS
        else:
            raise ValueError("Unknown sports {0}".format(self.sports))

        # Conditioning on the action moves the action columns from the MAF inputs to
        # its conditioning inputs; conditioning on values adds three more of the latter.
        if self.maf_cond_act:
            self.maf_num_inputs -= len(actions)
        if self.maf_cond_act or self.maf_cond_value:
            self.maf_num_cond_inputs = \
                (len(actions) if self.maf_cond_act else 0) + (3 if self.maf_cond_value else 0)
        else:
            self.maf_num_cond_inputs = None

        # lof features
        lof_cfg = general['lof']
        self.lof_apply_history = lof_cfg['apply_history']
        self.lof_neighbors = lof_cfg['neighbors']
        self.lof_metric = lof_cfg['metric']

    def construct_model(self, log_file):
        self.save_model_dict = {}
        # self.online_net = Spline_DQN(num_inputs=self.input_dim,
        #                              num_outputs=3,
        #                              num_support=16,
        #                              num_tau=16,
        #                              device=self.device).to(self.device)
        # self.target_net = Spline_DQN(num_inputs=self.input_dim,
        #                              num_outputs=3,
        #                              num_support=16,
        #                              num_tau=16,
        #                              device=self.device).to(self.device)
        if self.apply_rnn or self.apply_resnet:
            dqn_num_inputs = self.block_hidden_dim
        else:
            dqn_num_inputs = self.input_dim
        if 'train_distrib_rl' in self.task:
            # self.online_net = Spline_DQN_Single(num_inputs=self.input_dim,
            #                                     num_support=8,
            #                                     num_tau=16,
            #                                     device=self.device).to(self.device)
            # self.target_net = Spline_DQN_Single(num_inputs=self.input_dim,
            #                                     num_support=8,
            #                                     num_tau=16,
            #                                     device=self.device).to(self.device)
            self.online_net = Spline_DQN(num_inputs=dqn_num_inputs,
                                         num_outputs=3,
                                         num_support=self.num_supp,
                                         num_tau=self.num_tau,
                                         block_hidden_dim=self.block_hidden_dim,
                                         device=self.device).to(self.device)
            self.target_net = Spline_DQN(num_inputs=dqn_num_inputs,
                                         num_outputs=3,
                                         num_support=self.num_supp,
                                         num_tau=self.num_tau,
                                         block_hidden_dim=self.block_hidden_dim,
                                         device=self.device).to(self.device)
        elif 'train_rl' in self.task:
            self.online_net = DQN(num_inputs=dqn_num_inputs,
                                  block_hidden_dim=self.block_hidden_dim,
                                  num_outputs=3).to(self.device)
            self.target_net = DQN(num_inputs=dqn_num_inputs,
                                  block_hidden_dim=self.block_hidden_dim,
                                  num_outputs=3).to(self.device)
        self.save_model_dict.update({'online_net': self.online_net})

        if self.apply_rnn:
            self.rnn = nn.GRUCell(self.input_dim, self.block_hidden_dim).to(self.device)
            self.save_model_dict.update({'rnn': self.rnn})
        if self.apply_resnet:
            # self.resnet_all = []
            # if self.apply_resnet:
            #     total_resnet_num = self.max_trace_length
            # else:
            #     total_resnet_num = 1
            # for j in range(total_resnet_num):

            if self.split_state_action_resnet:
                state_resnet = torch.nn.Sequential()
                for i in range(self.resnet_layer_num):
                    state_resnet.add_module(name='State_ResBlock{0}'.format(i),
                                            module=ResBlock(num_inputs=self.input_dim - len(ICEHOCKEY_ACTIONS)))
                state_resnet.to(self.device)
                self.state_resnet = state_resnet
                self.save_model_dict.update({'state_resnet': self.state_resnet})

                action_resnet = torch.nn.Sequential()
                for i in range(self.resnet_layer_num):
                    action_resnet.add_module(name='Action_ResBlock{0}'.format(i),
                                             module=ResBlock(num_inputs=len(ICEHOCKEY_ACTIONS)))
                action_resnet.to(self.device)
                self.action_resnet = action_resnet
                self.save_model_dict.update({'action_resnet': self.action_resnet})
            else:
                resnet = torch.nn.Sequential()
                for i in range(self.resnet_layer_num):
                    resnet.add_module(name='ResBlock{0}'.format(i),
                                      module=ResBlock(num_inputs=self.input_dim))
                resnet.to(self.device)
                self.resnet = resnet
                self.save_model_dict.update({'resnet': self.resnet})
            # self.save_model_dict.update({'resnet_no.{0}'.format(j): resnet})
            # self.resnet_all.append(resnet)
        print("--------------------------------------------Model Settings--------------------------------------------",
              file=log_file, flush=True)
        feature_extractions = []
        if self.apply_rnn:
            feature_extractions.append('RNN')
        if self.apply_resnet:
            feature_extractions.append('Resnet')
        feature_extraction_str = ' + '.join(feature_extractions) if len(feature_extractions) > 0 else 'MLP'
        print("Feature Extraction Layers: {0}".format(feature_extraction_str), file=log_file, flush=True)
        print("Value Estimation Layers: {0}".format('Distributional DQN' if 'distrib' in self.task else 'DQN'),
              file=log_file, flush=True)
        print("--------------------------------------------End--------------------------------------------",
              file=log_file, flush=True)

        self.param_frozen_list = []
        self.frozen_parameters_keys = []
        self.param_active_list = []
        self.active_parameters_keys = []

        param_frozen_list, param_active_list = \
            self.handle_model_parameters(fix_keywords=[],
                                         model_name=self.task,
                                         log_file=log_file)
        self.optimizer = optim.Adam([{'params': param_frozen_list, 'lr': 0.0},
                                     {'params': param_active_list, 'lr': self.learning_rate}],
                                    lr=self.learning_rate)
        self.all_gda_models = None

    def match_model_parameters(self, params_list, parameters_info, fix_keywords):
        for k, v in params_list:
            keep_this = True
            size = torch.numel(v)
            parameters_info.append("{0}:{1}".format(k, size))
            for keyword in fix_keywords:
                if keyword in k:
                    self.param_frozen_list.append(v)
                    v.requires_grad = False
                    keep_this = False
                    self.frozen_parameters_keys.append(k)
                    break
            if keep_this:
                self.param_active_list.append(v)
                self.active_parameters_keys.append(k)

    def handle_model_parameters(self, fix_keywords, model_name, log_file):
        """determine which parameters should be fixed"""
        parameters_info = []

        self.match_model_parameters(params_list=self.online_net.named_parameters(),
                                    parameters_info=parameters_info,
                                    fix_keywords=fix_keywords)
        if self.apply_rnn:
            self.match_model_parameters(params_list=self.rnn.named_parameters(),
                                        parameters_info=parameters_info,
                                        fix_keywords=fix_keywords)
        if self.apply_resnet:
            # for resnet in self.resnet_all:
            if self.split_state_action_resnet:
                self.match_model_parameters(params_list=self.state_resnet.named_parameters(),
                                            parameters_info=parameters_info,
                                            fix_keywords=fix_keywords)
                self.match_model_parameters(params_list=self.action_resnet.named_parameters(),
                                            parameters_info=parameters_info,
                                            fix_keywords=fix_keywords)
            else:
                self.match_model_parameters(params_list=self.resnet.named_parameters(),
                                            parameters_info=parameters_info,
                                            fix_keywords=fix_keywords)
        print('-' * 30 + '{0} Optimizer'.format(model_name) + '-' * 30, file=log_file, flush=True)
        print("Fixed parameters are: {0}".format(str(self.frozen_parameters_keys)), file=log_file, flush=True)
        print("Active parameters are: {0}".format(str(self.active_parameters_keys)), file=log_file, flush=True)
        # print(parameters_info, file=log_file, flush=True)
        param_frozen_list = torch.nn.ParameterList(self.param_frozen_list)
        param_active_list = torch.nn.ParameterList(self.param_active_list)
        print('-' * 60, file=log_file, flush=True)

        return param_frozen_list, param_active_list

    def load_sports_data(self, game_label, need_check=True, sanity_check_msg=None):
        """Load the state-action and reward arrays of one game.

        :param game_label: sub-directory of ``self.train_data_path`` holding the game's pickles.
        :param need_check: when true, assert that the feature width equals ``self.input_dim``.
        :param sanity_check_msg: optional tag selecting a column subset of the state-action data
            (e.g. 'sanity_check_location_ha_', 'sanity_check_sd_md_tr_ha_' or one containing 'no_action').
        :return: ``(s_a_data, r_data)`` with one row per event.
        """
        # Feature columns, in order: 'xAdjCoord', 'yAdjCoord', 'scoreDifferential', 'Penalty',
        # 'duration', 'velocity_x', 'velocity_y', 'time_remained', 'event_outcome', 'home',
        # 'away', 'angel2gate', followed by the action columns.
        game_dir = self.train_data_path + '/' + game_label
        s_a_data = load_event_data(data_path=game_dir + '/state-action-data.pkl')

        if sanity_check_msg is not None:
            wants_home_away = 'ha' in sanity_check_msg
            if 'location' in sanity_check_msg and wants_home_away:
                # columns 0-1, 9-10 and the trailing action block
                column_blocks = [s_a_data[:, :2], s_a_data[:, 9:11], s_a_data[:, -len(ICEHOCKEY_ACTIONS):]]
                s_a_data = np.concatenate(column_blocks, axis=1)
            elif 'sd' in sanity_check_msg and 'md' in sanity_check_msg and wants_home_away:
                # columns 2-3, 7, 9-10 and the trailing action block
                column_blocks = [s_a_data[:, 2:4], s_a_data[:, 7:8],
                                 s_a_data[:, 9:11], s_a_data[:, -len(ICEHOCKEY_ACTIONS):]]
                s_a_data = np.concatenate(column_blocks, axis=1)
            elif 'no_action' in sanity_check_msg:
                # everything except the trailing action block
                s_a_data = s_a_data[:, :-len(ICEHOCKEY_ACTIONS)]

        if need_check:
            assert s_a_data.shape[1] == self.input_dim

        r_data = load_event_data(data_path=game_dir + '/reward-data.pkl')
        assert len(s_a_data) == len(r_data)
        return s_a_data, r_data

    def load_player_id(self, game_label):
        """Return the 'playerId' entry of every event read for ``game_label``, in event order."""
        feature = 'playerId'
        events = read_features_within_events(game_label=game_label,
                                             source_data_dir=self.source_data_dir,
                                             feature_name_list=[feature],
                                             sports=self.sports)
        return [event[feature] for event in events]

    def load_team_id(self, game_label):
        """Return the 'teamId' entry of every event read for ``game_label``, in event order."""
        feature = 'teamId'
        events = read_features_within_events(game_label=game_label,
                                             source_data_dir=self.source_data_dir,
                                             feature_name_list=[feature],
                                             sports=self.sports)
        return [event[feature] for event in events]

    def build_transitions(self, s_a_data, r_data, pid_sequence):
        """Turn one game's event stream into single-step transitions.

        Walks the events in order and emits one ``Transition`` per consecutive pair
        (current event -> next event). The trace fields of the transition are left
        as ``None``. How reward-bearing events are handled depends on
        ``self.cut_at_goal`` and ``self.keep_goal_state`` (see inline comments).

        :param s_a_data: per-event state-action feature rows (indexable, same length as ``r_data``).
        :param r_data: per-event reward entries; indices 0, 1 and 2 are stored as the three
            reward fields of the transition (presumably home / away / neither -- confirm
            against the data pipeline).
        :param pid_sequence: per-event player ids, indexed like ``s_a_data``.
        :return: list of ``Transition`` objects.

        NOTE(review): with ``keep_goal_state`` the loop runs one extra iteration with
        ``event_num == len(s_a_data)``. Unless ``skip`` was set by the last real event,
        that iteration indexes ``s_a_data[len(s_a_data)]`` and raises ``IndexError``;
        this looks safe only if every game ends on a reward-bearing event -- confirm.
        """
        curr_step_s_a = to_pt(s_a_data[0], enable_cuda=self.enable_cuda, type='float')
        curr_step_r_data = r_data[0]
        transition_all = []

        skip = False  # set when the event just examined carried a reward that ends the trace
        next_done = 0  # carries the terminal flag over to the following iteration (keep_goal_state only)
        # keep_goal_state adds one extra iteration past the last event
        event_length = len(s_a_data) + 1 if self.keep_goal_state else len(s_a_data)
        for event_num in range(1, event_length):
            # the ending event has no next event
            if skip and self.keep_goal_state:
                # The previous event ended the trace: emit a self-transition for it
                # and flag that transition as terminal.
                next_step_s_a = curr_step_s_a
                next_step_r_data = curr_step_r_data
                next_done = 1
            else:
                next_step_s_a = to_pt(s_a_data[event_num], enable_cuda=self.enable_cuda, type='float')
                next_step_r_data = r_data[event_num]

            # Does the event at event_num carry any of the three rewards?
            if event_num < len(s_a_data) and (r_data[event_num][0] or
                                              r_data[event_num][1] or
                                              r_data[event_num][2]):
                if self.cut_at_goal:
                    # Every reward ends the trace; without keep_goal_state this
                    # transition itself is terminal.
                    done = 0 if self.keep_goal_state else 1
                    skip = True
                elif next_step_r_data[2]:
                    # Not cutting at every reward: only the third reward entry ends the trace.
                    done = 0 if self.keep_goal_state else 1
                    skip = True
                else:
                    done = 0
                    skip = False
                # print(event_num, done)
            else:
                done = next_done if self.keep_goal_state else 0
                next_done = 0
                skip = False

            # keep_goal_state stores the reward of the current event, otherwise that of the next one.
            check_step_r_data = curr_step_r_data if self.keep_goal_state else next_step_r_data
            if curr_step_s_a is not None:  # None marks the spot right after a broken trace
                pid = pid_sequence[event_num] if event_num < len(s_a_data) else pid_sequence[event_num - 1]
                transition = Transition(curr_step_s_a, None, next_step_s_a, None,
                                        check_step_r_data[0], check_step_r_data[1], check_step_r_data[2],
                                        pid,
                                        done)
                transition_all.append(transition)

            # Advance the "current" event for the next iteration.
            if event_num < len(s_a_data):
                if skip and not self.keep_goal_state:  # break the trace !
                    curr_step_s_a = None
                    curr_step_r_data = None
                else:
                    if done and self.keep_goal_state:
                        # After a terminal self-transition, restart from the event at event_num.
                        curr_step_s_a = to_pt(s_a_data[event_num], enable_cuda=self.enable_cuda, type='float')
                        curr_step_r_data = r_data[event_num]
                    else:
                        curr_step_s_a = next_step_s_a
                        curr_step_r_data = next_step_r_data

        return transition_all

    def build_rnn_transitions(self, s_a_data, r_data, pid_sequence):
        """Turn one game's event stream into sequence (trace) transitions for the RNN.

        Works like ``build_transitions`` but the current and next sides of every
        transition are lists of events (traces) rather than single events. A trace
        grows by one event per step, is capped at ``self.max_trace_length`` (oldest
        events dropped) and, with ``self.apply_dynamic_trace_length``, restarts
        whenever the value of the 'home' feature changes between consecutive events.
        Before returning, every trace is padded with ``pad_sequence`` and converted
        with ``to_pt``; the reward lists are padded with zeros to ``max_trace_length``.

        :param s_a_data: per-event state-action feature rows (indexable, same length as ``r_data``).
        :param r_data: per-event reward entries; indices 0, 1 and 2 feed ``reward_h``,
            ``reward_a`` and ``reward_n``.
        :param pid_sequence: per-event player ids, indexed like ``s_a_data``.
        :return: list of padded ``Transition`` objects; ``trace`` / ``next_trace`` hold the
            unpadded sequence lengths.
        :raises ValueError: if ``self.sports`` is neither 'ice-hockey' nor 'soccer'.

        NOTE(review): with ``keep_goal_state`` the loop runs one extra iteration with
        ``event_num == len(s_a_data)``. Unless ``skip`` was set by the last real event,
        that iteration indexes ``s_a_data[len(s_a_data)]`` and raises ``IndexError``;
        this looks safe only if every game ends on a reward-bearing event -- confirm.
        """
        if self.sports == 'ice-hockey':
            all_features = ICEHOCKEY_GAME_FEATURES + ICEHOCKEY_ACTIONS
        elif self.sports == 'soccer':
            all_features = SOCCER_GAME_FEATURES + SOCCER_ACTIONS
        else:
            raise ValueError("Unknown sports {0}".format(self.sports))
        # The 'home' feature value of the previous event, used to detect a change of team.
        pre_step_team = s_a_data[0][all_features.index('home')]
        curr_sequence_step_s_a = [s_a_data[0]]
        curr_sequence_step_r_data = [r_data[0]]
        transition_all = []
        skip = False  # set when the event just examined carried a reward that ends the trace
        next_done = 0  # carries the terminal flag over to the following iteration (keep_goal_state only)
        event_length = len(s_a_data) + 1 if self.keep_goal_state else len(
            s_a_data)  # the ending event has no next event
        for event_num in range(1, event_length):
            # On the extra keep_goal_state iteration curr_step_team keeps its previous value.
            if event_num < len(s_a_data):
                curr_step_team = s_a_data[event_num][all_features.index('home')]
            if skip and self.keep_goal_state:  # the goal event is cut, will not move to the next state
                next_sequence_step_s_a = curr_sequence_step_s_a
                next_sequence_step_r_data = curr_sequence_step_r_data
                next_done = 1
            else:
                if self.apply_dynamic_trace_length:
                    if curr_step_team == pre_step_team:
                        # Same 'home' value: extend the running trace.
                        next_sequence_step_s_a = curr_sequence_step_s_a + [s_a_data[event_num]]
                        next_sequence_step_r_data = curr_sequence_step_r_data + [r_data[event_num]]
                    else:
                        # 'home' value changed: start a fresh trace at this event.
                        next_sequence_step_s_a = [s_a_data[event_num]]
                        next_sequence_step_r_data = [r_data[event_num]]
                else:
                    next_sequence_step_s_a = curr_sequence_step_s_a + [s_a_data[event_num]]
                    next_sequence_step_r_data = curr_sequence_step_r_data + [r_data[event_num]]

                # Keep only the most recent max_trace_length events.
                if len(next_sequence_step_s_a) > self.max_trace_length:
                    next_sequence_step_s_a = next_sequence_step_s_a[-self.max_trace_length:]
                    next_sequence_step_r_data = next_sequence_step_r_data[-self.max_trace_length:]

            # check_event_idx = event_num-1 if self.keep_goal_state else event_num  # event_num-1: current,event_num: next
            check_event_idx = event_num
            # keep_goal_state stores the rewards of the current trace, otherwise those of the next one.
            check_seq_step_r_data = curr_sequence_step_r_data if self.keep_goal_state else next_sequence_step_r_data
            # Does the event at event_num carry any of the three rewards?
            if event_num < len(s_a_data) and (r_data[check_event_idx][0] or
                                              r_data[check_event_idx][1] or
                                              r_data[check_event_idx][2]):
                if self.cut_at_goal:
                    done = 0 if self.keep_goal_state else 1
                    skip = True
                elif next_sequence_step_r_data[-1][2]:  # the game is ending
                    done = 0 if self.keep_goal_state else 1
                    skip = True
                else:
                    done = 0
                    skip = False
                # print(event_num, r_data[event_num], done)
            else:
                done = next_done if self.keep_goal_state else 0
                next_done = 0
                skip = False

            if len(curr_sequence_step_s_a) > 0 and len(curr_sequence_step_r_data) > 0:  # skip the place we break
                pid = pid_sequence[event_num] if event_num < len(s_a_data) else pid_sequence[event_num - 1]
                transition = Transition(curr_sequence_step_s_a,
                                        len(curr_sequence_step_s_a),
                                        next_sequence_step_s_a,
                                        len(next_sequence_step_s_a),
                                        [step_r_data[0] for step_r_data in check_seq_step_r_data],
                                        [step_r_data[1] for step_r_data in check_seq_step_r_data],
                                        [step_r_data[2] for step_r_data in check_seq_step_r_data],
                                        pid,
                                        done)
                transition_all.append(transition)
                assert len(next_sequence_step_r_data) == len(next_sequence_step_s_a)

            # Advance the "current" trace for the next iteration.
            if event_num < len(s_a_data):
                if skip and not self.keep_goal_state:  # break the trace !
                    curr_sequence_step_s_a = []
                    curr_sequence_step_r_data = []
                else:
                    if done and self.keep_goal_state:
                        # After a terminal self-transition, restart the trace at the event at event_num.
                        curr_sequence_step_s_a = [s_a_data[event_num]]
                        curr_sequence_step_r_data = [r_data[event_num]]
                    else:
                        curr_sequence_step_s_a = next_sequence_step_s_a
                        curr_sequence_step_r_data = next_sequence_step_r_data
                pre_step_team = curr_step_team

        # print(transition_all[-1])
        # print(r_data[-1])

        # Pad every trace and reward list to max_trace_length and convert the traces with to_pt.
        transition_all_pad = []
        for transition in transition_all:
            transition_pad = Transition(to_pt(pad_sequence(sequence=transition.state_action,
                                                           max_length=self.max_trace_length),
                                              enable_cuda=self.enable_cuda, type='float'),
                                        transition.trace,
                                        to_pt(pad_sequence(sequence=transition.next_state_action,
                                                           max_length=self.max_trace_length),
                                              enable_cuda=self.enable_cuda, type='float'),
                                        transition.next_trace,
                                        transition.reward_h + [0 for i in range(self.max_trace_length - len(
                                            transition.reward_h))],
                                        transition.reward_a + [0 for i in range(self.max_trace_length - len(
                                            transition.reward_a))],
                                        transition.reward_n + [0 for i in range(self.max_trace_length - len(
                                            transition.reward_n))],
                                        transition.pid,
                                        transition.done)
            transition_all_pad.append(transition_pad)

        return transition_all_pad

    def select_data_by_action(self, transition_all, sanity_check_msg):
        """Keep only the transitions whose last trace step performs the 'shot' action.

        The last step of each trace (index ``trace - 1``) is mapped back through
        ``reverse_standard_data``; the performed action is the entry of
        ``ICEHOCKEY_ACTIONS`` with the largest positive label (first one on ties).

        :param transition_all: transitions carrying a trace length (as produced for the RNN).
        :param sanity_check_msg: forwarded unchanged to ``reverse_standard_data``.
        :return: list with the selected transitions, in their original order.
        """
        shot_transitions = []
        for transition in transition_all:
            last_step = transition.state_action[transition.trace - 1]
            restored = reverse_standard_data(state_action_data=to_np(last_step),
                                             data_means=self.data_means,
                                             data_stds=self.data_stds,
                                             sanity_check_msg=sanity_check_msg)
            # Find which action is performed: strictly largest label above zero.
            performed_action, best_label = None, 0
            for action_name in ICEHOCKEY_ACTIONS:
                if restored[action_name] > best_label:
                    best_label = restored[action_name]
                    performed_action = action_name
            if performed_action == 'shot':
                shot_transitions.append(transition)
        return shot_transitions

    def get_transition_batch(self, transition_all, counter):
        """Return batch number ``counter`` as one ``Transition`` of per-field tuples.

        :param transition_all: sequence of ``Transition`` objects.
        :param counter: zero-based batch index; the batch covers ``self.batch_size`` items.
        """
        first = counter * self.batch_size
        last = (counter + 1) * self.batch_size
        chunk = transition_all[first:last]
        # Transpose the list of transitions into one tuple per field.
        columns = zip(*chunk)
        return Transition(*columns)

    def update_target_net(self):
        # print('updating target net', file=self.log_file, flush=True)
        self.target_net.load_state_dict(self.online_net.state_dict())

    def update_dqn_model(self, batch):
        state_actions = torch.stack(batch.state_action).to(self.device)
        next_state_actions = torch.stack(batch.next_state_action).to(self.device)  # next_state + next_action
        rewards_h = torch.Tensor(batch.reward_h).to(self.device)
        rewards_a = torch.Tensor(batch.reward_a).to(self.device)
        rewards_n = torch.Tensor(batch.reward_n).to(self.device)
        batch_size = len(state_actions)

        # if rewards_h.sum() > 0 or rewards_a.sum() > 0:
        #     print('debug')

        if self.apply_rnn:
            online_hidden_states = None
            trace_mask = build_trace_mask(trace=batch.trace, max_trace_length=self.max_trace_length)
            q_values_rnn = []
            for i in range(self.max_trace_length):
                rnn_input = state_actions[:, i, :]
                online_hidden_states = self.rnn(input=rnn_input, hx=online_hidden_states)
                q_values_step = self.online_net(state_action=online_hidden_states)
                q_values_rnn.append(q_values_step)

            q_values_rnn = torch.stack(q_values_rnn, dim=1)
            trace_mask = np.expand_dims(trace_mask, axis=(2)).repeat(3, axis=2)
            trace_mask = to_pt(trace_mask, enable_cuda=self.enable_cuda)
            q_values = torch.sum(q_values_rnn * trace_mask, dim=1)

            # compute the target
            target_hidden_states = None
            next_trace_mask = build_trace_mask(trace=batch.next_trace, max_trace_length=self.max_trace_length)
            next_q_values_rnn = []
            for i in range(self.max_trace_length):
                # compute the current
                rnn_input = next_state_actions[:, i, :]
                target_hidden_states = self.rnn(input=rnn_input, hx=target_hidden_states)
                next_q_values_step = self.target_net(state_action=target_hidden_states)
                next_q_values_rnn.append(next_q_values_step)
            next_q_values_rnn = torch.stack(next_q_values_rnn, dim=1)
            next_trace_mask = np.expand_dims(next_trace_mask, axis=(2)).repeat(3, axis=2)
            next_trace_mask = to_pt(next_trace_mask, enable_cuda=self.enable_cuda)
            next_q_values = torch.sum(next_q_values_rnn * next_trace_mask, dim=1)

            rewards_h = torch.stack([rewards_h[bid, batch.next_trace[bid] - 1] for bid in range(batch_size)])
            rewards_a = torch.stack([rewards_a[bid, batch.next_trace[bid] - 1] for bid in range(batch_size)])
            rewards_n = torch.stack([rewards_n[bid, batch.next_trace[bid] - 1] for bid in range(batch_size)])
        else:
            q_values = self.online_net(state_actions)
            next_q_values = self.online_net(next_state_actions)  # (batch_size, num_tau)
        rewards = torch.stack([rewards_h, rewards_a, rewards_n], dim=1)
        if torch.sum(rewards) != 0:
            print('score')
        dones = torch.Tensor(batch.done).to(self.device)  # done
        bellman_target = rewards + self.gamma * (1.0 - dones.unsqueeze(1)) * next_q_values

        if torch.sum(rewards_h) > 0 or torch.sum(rewards_a) > 0 or torch.sum(rewards_n) > 0:
            assert torch.sum(dones) > 0
            # assert torch.sum(dones) == torch.sum(rewards)

        loss = torch.square(q_values - bellman_target.detach())
        loss = torch.mean(loss)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    def compute_distrib_dqn_loss(self,
                                 theta, next_theta,
                                 rewards_h, rewards_a, rewards_n,
                                 dones, num_tau,
                                 if_loss_mean=True):
        """Quantile-regression Huber loss averaged over the three outcome heads.

        The same loss is evaluated independently for head 0 (home), head 1
        (away) and head 2 (neither) and the three results are averaged.

        Arguments:
            :param theta: current quantile estimates, indexed as
                ``theta[:, head, :]`` with ``num_tau`` quantiles per head
            :param next_theta: quantile estimates of the next state-action,
                same layout as ``theta``
            :param rewards_h: per-sample reward of the home head
            :param rewards_a: per-sample reward of the away head
            :param rewards_n: per-sample reward of the neither head
            :param dones: per-sample terminal flag; a value of 1 removes the
                bootstrapped term from the target
            :param num_tau: number of quantiles per head
            :param if_loss_mean: if True the loss is averaged over the batch
                and a scalar is returned, otherwise one value per sample

        Returns:
            The mean of the three per-head losses.
        """
        # Quantile midpoints (k + 0.5) / num_tau. They are the same for every head,
        # so they are built once instead of once per head.
        tau = torch.arange(0.5 * (1 / num_tau), 1, 1 / num_tau).view(1, num_tau).to(self.device)
        # Discount applied to the bootstrapped quantiles; zero for terminal transitions.
        discount = self.gamma * (1.0 - dones.unsqueeze(1))

        head_losses = []
        for head_idx, rewards in enumerate((rewards_h, rewards_a, rewards_n)):
            current = theta[:, head_idx, :]
            target = rewards.unsqueeze(1) + discount * next_theta[:, head_idx, :]

            # Pair every target quantile (dim 1) with every current quantile (dim 2).
            target_tile = target.view(-1, num_tau, 1).expand(-1, num_tau, num_tau)
            current_tile = current.view(-1, 1, num_tau).expand(-1, num_tau, num_tau)

            error = target_tile - current_tile
            # The target is detached: gradients only flow through the current estimates.
            huber = F.smooth_l1_loss(current_tile, target_tile.detach(), reduction='none')

            # Asymmetric weighting |tau - 1{error < 0}| turns the Huber loss into a quantile loss.
            head_loss = (tau - (error < 0).float()).abs() * huber
            head_loss = head_loss.mean(dim=2).sum(dim=1)
            if if_loss_mean:
                head_loss = head_loss.mean()
            head_losses.append(head_loss)

        loss_h, loss_a, loss_n = head_losses
        return (loss_h + loss_a + loss_n) / 3.

    def update_distrib_dqn_model(self, num_tau, batch):
        """Run one optimisation step of the distributional (quantile) DQN.

        Arguments:
            :param num_tau: number of quantiles for quantile regression, should be
                consistent with the model's num_tau
            :param batch: batch of transitions exposing ``state_action``,
                ``next_state_action``, ``reward_h``, ``reward_a``, ``reward_n``,
                ``done`` and, when ``self.apply_rnn`` is set, ``trace`` and
                ``next_trace``

        Returns:
            The loss tensor produced by ``compute_distrib_dqn_loss`` for this batch.
        """
        state_actions = torch.stack(batch.state_action).to(self.device)  # s_t, a_t
        next_state_actions = torch.stack(batch.next_state_action).to(self.device)  # s_t+1, a_t+1
        rewards_h = torch.Tensor(batch.reward_h).to(self.device)  # home reward
        rewards_a = torch.Tensor(batch.reward_a).to(self.device)  # away reward
        rewards_n = torch.Tensor(batch.reward_n).to(self.device)  # neither reward
        dones = torch.Tensor(batch.done).to(self.device)  # done
        batch_size = len(state_actions)

        def encode(features):
            """Apply the optional ResNet front-end to a 2-D (batch, feature) tensor."""
            if not self.apply_resnet:
                return features
            if self.split_state_action_resnet:
                # NOTE(review): the state/action split always uses ICEHOCKEY_ACTIONS,
                # also when self.sports == 'soccer' -- confirm this is intended.
                num_actions = len(ICEHOCKEY_ACTIONS)
                state_part = self.state_resnet(features[:, :-num_actions])
                action_part = self.action_resnet(features[:, -num_actions:])
                return torch.cat([state_part, action_part], dim=1)
            return self.resnet(features)

        def rollout(sequences, trace, head):
            """Unroll self.rnn over the trace and return the mask-weighted sum of head outputs."""
            mask = build_trace_mask(trace=trace, max_trace_length=self.max_trace_length)
            hidden = None
            per_step = []
            for step in range(self.max_trace_length):
                hidden = self.rnn(input=encode(sequences[:, step, :]), hx=hidden)
                per_step.append(head(state_action=hidden))
            per_step = torch.stack(per_step, dim=1)
            # Broadcast the (batch, step) mask over the 3 heads and the quantiles.
            mask = np.expand_dims(mask, axis=(2, 3)).repeat(3, axis=2).repeat(self.num_tau, axis=3)
            mask = to_pt(mask, enable_cuda=self.enable_cuda)
            return torch.sum(per_step * mask, dim=1)

        if self.apply_rnn:
            # compute the current
            theta = rollout(state_actions, batch.trace, self.online_net)
            # compute the target; it is detached inside the loss, so no autograd
            # graph is needed for this pass.
            with torch.no_grad():
                next_theta = rollout(next_state_actions, batch.next_trace, self.target_net)

            # Rewards are stored per trace step; keep the one at the last step of the next trace.
            last_steps = [batch.next_trace[bid] - 1 for bid in range(batch_size)]
            rewards_h = torch.stack([rewards_h[bid, step] for bid, step in enumerate(last_steps)])
            rewards_a = torch.stack([rewards_a[bid, step] for bid, step in enumerate(last_steps)])
            rewards_n = torch.stack([rewards_n[bid, step] for bid, step in enumerate(last_steps)])
        else:
            encoded = encode(state_actions)
            with torch.no_grad():
                encoded_next = encode(next_state_actions)
            # compute the current
            theta = self.online_net(encoded)
            # compute the target (never back-propagated through)
            with torch.no_grad():
                next_theta = self.target_net(encoded_next)  # (batch_size, 3, num_tau)

        # Sanity check on the data: a positive reward is only expected in a batch
        # that also contains a terminal transition.
        if torch.sum(rewards_h) > 0 or torch.sum(rewards_a) > 0 or torch.sum(rewards_n) > 0:
            assert torch.sum(dones) > 0

        loss = self.compute_distrib_dqn_loss(theta=theta,
                                             next_theta=next_theta,
                                             rewards_h=rewards_h,
                                             rewards_a=rewards_a,
                                             rewards_n=rewards_n,
                                             dones=dones,
                                             num_tau=num_tau)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    def save_maf(self, save_to_path, episode_no, avg_log_prob, avg_loss, log_file):
        """Write a checkpoint of the MAF density model and its optimizer.

        Arguments:
            :param save_to_path: destination passed to ``torch.save``
            :param episode_no: stored under the 'epoch' key
            :param avg_log_prob: stored under the 'avg_log_prob' key
            :param avg_loss: stored under the 'avg_loss' key
            :param log_file: stream that receives the confirmation message
        """
        checkpoint = {}
        checkpoint['epoch'] = episode_no
        checkpoint['model_state_dict'] = self.maf_model.state_dict()
        checkpoint['optimizer_state_dict'] = self.maf_optim.state_dict()
        checkpoint['avg_log_prob'] = avg_log_prob
        checkpoint['avg_loss'] = avg_loss
        torch.save(checkpoint, save_to_path)
        print("Saved checkpoint to %s" % (save_to_path), file=log_file, flush=True)

    def load_maf(self, load_from, log_file=None):
        """Restore the MAF density model and its optimizer from a checkpoint.

        Only checkpoint parameters whose names exist in the current model are
        loaded; the rest are reported as omitted.

        Arguments:
            :param load_from: checkpoint path passed to ``torch.load``
            :param log_file: stream for progress messages (stdout when None)

        Returns:
            ``(avg_loss, avg_log_prob)`` as stored in the checkpoint, or
            ``(0, 0)`` when anything goes wrong while loading (the traceback
            is printed and the failure is reported to ``log_file``).
        """
        print("loading model from %s" % (load_from), file=log_file, flush=True)
        try:
            # Without CUDA every tensor is remapped onto the CPU.
            load_kwargs = {} if self.enable_cuda else {'map_location': 'cpu'}
            checkpoint = torch.load(load_from, **load_kwargs)

            saved_params = checkpoint['model_state_dict']
            merged_state = self.maf_model.state_dict()
            matching = {name: value for name, value in saved_params.items() if name in merged_state}
            load_keys = list(matching)
            omit_keys = [name for name in saved_params if name not in merged_state]

            merged_state.update(matching)
            self.maf_model.load_state_dict(merged_state)
            self.maf_optim.load_state_dict(checkpoint['optimizer_state_dict'])
            load_avg_log_prob = checkpoint['avg_log_prob']
            load_avg_loss = checkpoint['avg_loss']

            print("Loaded model parameters are:" + ", ".join(load_keys), file=log_file, flush=True)
            print("Omitted model parameters are:" + ", ".join(omit_keys), file=log_file, flush=True)
            print("--------------------------\n", file=log_file, flush=True)

            return load_avg_loss, load_avg_log_prob
        except Exception:
            # Best effort: report the problem and hand back neutral statistics.
            traceback.print_exc()
            print("Failed to load checkpoint...\n", file=log_file, flush=True)
            return 0, 0

    def save_model_to_path(self, save_to_path, episode_no,
                           eval_distance=None, std_distance=None, coorelation=None, log_file=None):
        """Write a checkpoint holding every model in ``self.save_model_dict``.

        The state dicts of all registered models are merged into one flat
        mapping, so identical parameter names in two models overwrite each
        other (the later model wins).

        Arguments:
            :param save_to_path: destination passed to ``torch.save``
            :param episode_no: stored under the 'epoch' key
            :param eval_distance: stored under the 'eval_distance' key
            :param std_distance: stored under the 'std_distance' key
            :param coorelation: stored under the 'avg_correl' key
            :param log_file: stream for the confirmation message (stdout when None)
        """
        merged_state = OrderedDict()
        for component in self.save_model_dict.values():
            merged_state.update(component.state_dict())

        checkpoint = {}
        checkpoint['epoch'] = episode_no
        checkpoint['model_state_dict'] = merged_state
        checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
        checkpoint['eval_distance'] = eval_distance
        checkpoint['std_distance'] = std_distance
        checkpoint['avg_correl'] = coorelation
        torch.save(checkpoint, save_to_path)
        print("Saved checkpoint to %s" % (save_to_path), file=log_file, flush=True)

    def load_pretrained_model(self, load_from, load_optim=True, log_file=None):
        """
        Load a pretrained checkpoint into the models of ``self.save_model_dict``.

        Checkpoint parameter names containing 'rnn' have the 'rnn.' substring
        removed before they are matched against the current models. Parameters
        that match no model are reported as omitted.

        Arguments:
            :param load_from: model dir
            :param load_optim: if load optimizer
            :param log_file: log file

        Returns:
            ``(load_keys, episode_num, eval_distance, std_distance, avg_correl)``;
            on any failure the traceback is printed and ``([], 0, 0, 0, 0)``
            is returned instead.
        """
        print("loading model from %s" % (load_from), file=log_file, flush=True)
        try:
            # Without CUDA every tensor is remapped onto the CPU.
            load_kwargs = {} if self.enable_cuda else {'map_location': 'cpu'}
            checkpoint = torch.load(load_from, **load_kwargs)

            model_checkpoints = collections.OrderedDict(
                (name.replace('rnn.', '') if 'rnn' in name else name, value)
                for name, value in checkpoint['model_state_dict'].items()
            )

            optimizer_checkpoints = checkpoint['optimizer_state_dict']
            episode_num = checkpoint['epoch']
            # Older checkpoints may lack the evaluation statistics; fall back to
            # the worst possible value for each of them.
            eval_distance = checkpoint.get('eval_distance', float('inf'))
            std_distance = checkpoint.get('std_distance', float('inf'))
            avg_correl = checkpoint.get('avg_correl', -float('inf'))

            known_params = OrderedDict()
            for component in self.save_model_dict.values():
                component_state = component.state_dict()
                component_state.update(
                    {name: value for name, value in model_checkpoints.items() if name in component_state})
                component.load_state_dict(component_state)
                known_params.update(component_state)

            if load_optim:
                self.optimizer.load_state_dict(optimizer_checkpoints)

            print("Successfully load model with epoch:{0}, "
                  "eval_distance:{1}, std_distance:{2}, avg_correl:{3}".
                  format(episode_num, eval_distance, std_distance, avg_correl), file=log_file, flush=True)
            load_keys = [name for name in model_checkpoints if name in known_params]
            omit_keys = [name for name in model_checkpoints if name not in known_params]
            print("Loaded model parameters are:" + ", ".join(load_keys), file=log_file, flush=True)
            print("Omitted model parameters are:" + ", ".join(omit_keys), file=log_file, flush=True)
            print("--------------------------\n", file=log_file, flush=True)
            return load_keys, episode_num, eval_distance, std_distance, avg_correl
        except Exception:
            # Best effort: report the problem and hand back neutral values.
            traceback.print_exc()
            print("Failed to load checkpoint...\n", file=log_file, flush=True)
            return [], 0, 0, 0, 0

    def compute_values_by_game(self, game_name, sanity_check_msg):
        """Evaluate the online network on every transition of one game.

        The game is loaded, converted to transitions and pushed through the
        (optional) ResNet front-end, the (optional) RNN and ``self.online_net``
        in chunks of ``self.batch_size`` without tracking gradients.

        As a side effect the network inputs produced by the front-end are
        stored in ``self.latent_features`` (a torch tensor; with
        ``self.apply_rnn`` it holds one entry per trace step, stacked on dim 1).

        Arguments:
            :param game_name: label of the game to load
            :param sanity_check_msg: forwarded to ``load_sports_data``

        Returns:
            ``(output_all, transition_game)`` where ``output_all`` is a numpy
            array with one row per transition and ``transition_game`` is the
            list of transitions built for the game.
        """
        with torch.no_grad():
            s_a_sequence, r_sequence = self.load_sports_data(game_label=game_name, sanity_check_msg=sanity_check_msg)
            pid_sequence = self.load_player_id(game_label=game_name)
            if self.apply_rnn:
                transition_game = self.build_rnn_transitions(s_a_data=s_a_sequence,
                                                             r_data=r_sequence,
                                                             pid_sequence=pid_sequence)
            else:
                transition_game = self.build_transitions(s_a_data=s_a_sequence,
                                                         r_data=r_sequence,
                                                         pid_sequence=pid_sequence)
            game_transition_data = Transition(*zip(*transition_game[0:]))
            state_actions_all = torch.stack(game_transition_data.state_action).to(self.device)
            if self.apply_rnn:
                trace_mask_all = build_trace_mask(trace=game_transition_data.trace,
                                                  max_trace_length=self.max_trace_length)

            def encode(features):
                """Apply the optional ResNet front-end to a 2-D (batch, feature) tensor."""
                if not self.apply_resnet:
                    return features
                if self.split_state_action_resnet:
                    # NOTE(review): the state/action split always uses ICEHOCKEY_ACTIONS,
                    # also when self.sports == 'soccer' -- confirm this is intended.
                    num_actions = len(ICEHOCKEY_ACTIONS)
                    state_output = self.state_resnet(features[:, :-num_actions])
                    action_output = self.action_resnet(features[:, -num_actions:])
                    return torch.cat([state_output, action_output], dim=1)
                return self.resnet(features)

            output_batches = []
            latent_batches = []
            # Stepping by batch_size never yields an empty trailing chunk, so the
            # networks are not called on zero samples when the game length is an
            # exact multiple of the batch size.
            for start in range(0, len(state_actions_all), self.batch_size):
                stop = start + self.batch_size
                state_actions = state_actions_all[start:stop]
                if self.apply_rnn:
                    trace_mask = trace_mask_all[start:stop, :]
                    if len(trace_mask) == 0:
                        continue
                    # Broadcast the (batch, step) mask over the 3 heads (and the quantiles).
                    if 'distrib' in self.task:
                        trace_mask = np.expand_dims(trace_mask, axis=(2, 3)).repeat(3, axis=2)\
                            .repeat(self.num_tau, axis=3)
                    else:
                        trace_mask = np.expand_dims(trace_mask, axis=(2)).repeat(3, axis=2)
                    hidden_state = None
                    step_outputs = []
                    step_latents = []
                    for i in range(self.max_trace_length):
                        rnn_input = encode(state_actions[:, i, :])
                        step_latents.append(rnn_input)
                        hidden_state = self.rnn(rnn_input, hidden_state)
                        step_outputs.append(to_np(self.online_net(hidden_state)))
                    # Mask-weighted sum over the trace steps.
                    output = np.sum(np.stack(step_outputs, axis=1) * trace_mask, axis=1)
                    latent_features = torch.stack(step_latents, dim=1)
                else:
                    # Same front-end as in training, including the split
                    # state/action ResNets when they are enabled.
                    latent_features = encode(state_actions)
                    output = to_np(self.online_net(latent_features))
                output_batches.append(output)
                latent_batches.append(latent_features)

            if not output_batches:
                raise ValueError("No transitions could be evaluated for game {0}".format(game_name))
            # Concatenate once instead of re-copying the accumulated result per batch.
            output_all = np.concatenate(output_batches, axis=0)
            latent_features_all = torch.cat(latent_batches, dim=0)

        assert len(latent_features_all) == len(state_actions_all)
        assert len(output_all) == len(state_actions_all)
        self.latent_features = latent_features_all
        return output_all, transition_game

    def compute_uncertainty_by_game(self,
                                    game_name,
                                    sanity_check_msg,
                                    uncertainty_model='gda',
                                    transition_game=None,
                                    use_home=None,
                                    ):
        """Compute one uncertainty score per transition of a game.

        With ``uncertainty_model='gda'`` the score is the entropy of the
        log-probabilities returned by a fitted GDA model; with ``'maf'`` it is
        the negative log-probability returned by ``validate_maf``.

        Arguments:
            :param game_name: label of the game to evaluate
            :param sanity_check_msg: forwarded to the data helpers
            :param uncertainty_model: 'gda' or 'maf'
            :param transition_game: transitions of the game (used by 'gda' only).
                When given, ``self.latent_features`` is read as-is, i.e. it is
                expected to come from an earlier ``compute_values_by_game``
                call; when None that call is made here.
            :param use_home: for the 'QValues' GDA target, forces the home
                (truthy) or away (falsy) model; when None the side is read
                from the de-standardised features of each transition

        Returns:
            ``(uncertainty_all, transition_game)``.

        Raises:
            ValueError: for an unknown ``uncertainty_model`` or GDA fitting target.
        """
        if uncertainty_model == 'gda':
            if transition_game is None:
                _, transition_game = self.compute_values_by_game(game_name, sanity_check_msg)
            latent_np = self.latent_features.cpu().detach().numpy()
            trace_lengths = [transition.trace for transition in transition_game]
            gda_features = handle_gda_features(fitting_target=self.gda_fitting_target,
                                               all_latent_features=latent_np,
                                               all_trace_length=trace_lengths,
                                               sanity_check_msg=sanity_check_msg,
                                               max_trace_length=self.max_trace_length,
                                               split_state_action_resnet=self.split_state_action_resnet,
                                               apply_history=self.gda_apply_history)
            per_transition = []
            for idx, transition in enumerate(transition_game):
                if self.gda_fitting_target == 'Actions':
                    # A single model is used for the 'Actions' target.
                    side = 'home'
                elif self.gda_fitting_target == 'QValues':
                    # De-standardise the last event of the trace to read its home/away indicators.
                    last_event = transition.state_action[transition.trace - 1]
                    origin = reverse_standard_data(state_action_data=to_np(last_event),
                                                   data_means=self.data_means,
                                                   data_stds=self.data_stds,
                                                   sanity_check_msg=sanity_check_msg,
                                                   sports=self.sports)
                    if use_home is not None:
                        pick_home = use_home
                    else:
                        pick_home = origin['home'] > origin['away']
                    side = 'home' if pick_home else 'away'
                else:
                    raise ValueError("Incorrect gda target")
                gda_model = self.all_gda_models[side]
                features = torch.tensor(gda_features[idx, None, :]).unsqueeze(0)
                class_log_probs = gda_model.log_prob(features)
                per_transition.append(entropy(logits=class_log_probs)[0])
            uncertainty_all = torch.stack(per_transition, dim=0).detach().cpu().numpy()
        elif uncertainty_model == 'maf':
            s_a_sequence, r_sequence = self.load_sports_data(game_label=game_name, sanity_check_msg=sanity_check_msg)
            pid_sequence = self.load_player_id(game_label=game_name)
            if self.apply_rnn:
                transition_game = self.build_rnn_transitions(s_a_data=s_a_sequence,
                                                             r_data=r_sequence,
                                                             pid_sequence=pid_sequence)
            else:
                transition_game = self.build_transitions(s_a_data=s_a_sequence,
                                                         r_data=r_sequence,
                                                         pid_sequence=pid_sequence)
            game_transition_data = Transition(*zip(*transition_game[0:]))

            with torch.no_grad():
                values_cond = None
                if self.maf_cond_value:
                    # Condition the flow on the mean over the last axis of the network output.
                    output_game, _ = self.compute_values_by_game(game_name=game_name,
                                                                  sanity_check_msg=sanity_check_msg)
                    values_cond = to_pt(np.mean(output_game, axis=2), enable_cuda=self.enable_cuda, type='float')
                _, log_prob = validate_maf(agent=self,
                                           batch=game_transition_data,
                                           sanity_check_msg=sanity_check_msg,
                                           batch_values_cond=values_cond)
                uncertainty_all = -log_prob
        else:
            raise ValueError("Uknown uncertainty model {0}".format(uncertainty_model))
        return uncertainty_all, transition_game

    def fit_gda(self,
                gda_fitting_target,
                debug_mode,
                sanity_check_msg,
                log_file, ):
        self.gda_fitting_target = gda_fitting_target
        all_files = sorted(os.listdir(self.train_data_path))
        training_files, validation_files, testing_files = \
            divide_dataset_according2date(all_data_files=all_files,
                                          train_rate=self.train_rate,
                                          sports=self.sports,
                                          if_split=self.apply_data_date_div
                                          )

        all_output = []
        all_latent_features = []
        all_trace_lengths = []
        all_transition_game = []
        if debug_mode:
            training_files = training_files[:2]
        # for game_name in tqdm(training_files, desc="Training gda model", file=log_file):
        for game_name in training_files:
            output, transition_game = self.compute_values_by_game(game_name, sanity_check_msg)
            all_output.append(output.mean(axis=2))
            all_transition_game += transition_game
            all_latent_features.append(self.latent_features.cpu().detach().numpy())
            all_trace_lengths += [transition_game[i].trace for i in range(len(transition_game))]
        all_output_cat = np.concatenate(all_output, axis=0)
        all_latent_features_cat = np.concatenate(all_latent_features, axis=0)
        all_gda_features_cat = handle_gda_features(fitting_target=self.gda_fitting_target,
                                                   all_latent_features=all_latent_features_cat,
                                                   all_trace_length=all_trace_lengths,
                                                   sanity_check_msg=sanity_check_msg,
                                                   max_trace_length=self.max_trace_length,
                                                   split_state_action_resnet=self.split_state_action_resnet,
                                                   apply_history=self.gda_apply_history, )

        # data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
        teams = ['home', 'away', 'end']
        all_gda_models = {}
        for i in range(3):
            all_labels = []
            all_actions = set()
            all_output_team = all_output_cat[:, i]
            if self.gda_fitting_target == 'QValues':
                num_classes = 10
                discretizer = QValueDiscretization(all_q_values=all_output_team,
                                                   split_num=num_classes,
                                                   discret_mode=self.gda_discret_mode)
                for output in all_output:
                    labels = discretizer.discretize_q_values(output[:, i])
                    # import matplotlib.pyplot as plt
                    # plt.figure()
                    # plt.plot(range(len(output)), output)
                    # plt.show()
                    all_labels.append(labels)
                all_labels = torch.tensor(np.concatenate(all_labels, axis=0))
                # tmp = all_labels.detach().numpy().tolist()
                class_labels = list(set(all_labels.detach().numpy().tolist()))
            elif self.gda_fitting_target == 'Actions':
                # num_classes = len(ACTIONS)
                action_visited_flag = [False for i in range(len(ICEHOCKEY_ACTIONS))]
                for j in range(len(all_latent_features_cat)):
                    state_action_data = all_transition_game[j].state_action[all_trace_lengths[j] - 1]
                    state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                                data_means=self.data_means,
                                                                data_stds=self.data_stds,
                                                                sanity_check_msg=sanity_check_msg)
                    max_action_label = 0
                    label = None
                    action = None
                    for k in range(len(ICEHOCKEY_ACTIONS)):
                        candidate_action = ICEHOCKEY_ACTIONS[k]
                        if state_action_origin[candidate_action] > max_action_label:
                            max_action_label = state_action_origin[candidate_action]
                            action = candidate_action
                            label = k
                    action_visited_flag[label] = True
                    all_labels.append(label)
                    all_actions.add(action)
                    # print(action)
                all_labels, num_classes = label_visiting_shrink(all_labels, action_visited_flag)
                all_labels = torch.tensor(all_labels)
                class_labels = [i for i in range(num_classes)]
            else:
                raise ValueError("unknown fitting_target {0}".format(self.gda_fitting_target))
            # tmp1 = all_gda_features_cat.detach().cpu().numpy()
            # for tmp in tmp1:
            #     print(tmp)
            # tmp2 = all_labels.detach().cpu().numpy()
            # for tmp in tmp2:
            #     print(tmp)
            # print(all_actions, file=log_file, flush=True)
            gda, _ = gmm_fit(embeddings=all_gda_features_cat,
                             labels=all_labels,
                             class_labels=class_labels,
                             apply_pd=self.gda_apply_pd)
            all_gda_models.update({teams[i]: gda})

            # for latent_features in all_gda_features_cat:
            # log_probs_B_Y = gda.log_prob(all_gda_features_cat[:, None, :])
            # tmp = log_probs_B_Y.detach().cpu().numpy()
            # uncertainty = entropy(logits=log_probs_B_Y)
            # uncertainty = uncertainty.detach().cpu().numpy()
            # print(uncertainty)
            # train_densities = torch.logsumexp(log_probs_B_Y, dim=-1)
            # train_min_density = train_densities.min().item()
            # uncertainty = train_densities - train_min_density
        self.all_gda_models = all_gda_models

    def save_gda(self, gda_model_path):
        """Pickle ``self.all_gda_models`` to disk and log the completion.

        :param gda_model_path: destination file path; it is opened in binary
            write mode, so an existing file at that path is overwritten.
        """
        with open(gda_model_path, 'wb') as sink:
            # Explicit Pickler with default settings; same as pickle.dump().
            pickle.Pickler(sink).dump(self.all_gda_models)
        message = "Finish saving the gda model {0}".format(gda_model_path)
        print(message, file=self.log_file, flush=True)

    def load_gda(self, gda_model_path):
        """Restore ``self.all_gda_models`` from a pickle file and log it.

        NOTE(security): unpickling can execute arbitrary code, so only pass
        paths to files that come from a trusted source.

        :param gda_model_path: path of the pickle file to read.
        """
        with open(gda_model_path, 'rb') as source:
            # Explicit Unpickler with default settings; same as pickle.load().
            self.all_gda_models = pickle.Unpickler(source).load()
        message = "Finish loading the gda model {0}".format(gda_model_path)
        print(message, file=self.log_file, flush=True)
